load("//bazel:api.bzl", "lit_tests", "mojo_filecheck_test", "mojo_test")

# Sources handed to the `lit_tests` target in this file. They are excluded
# from the glob that generates the per-file `mojo_test` targets, so each one
# is built by exactly one rule.
_LIT_TESTS = [
    "test_accelerator_arch_cli.mojo",
    "test_accelerator_arch_cli_kernels.mojo",
]

# Sources that get a `mojo_filecheck_test` target instead of a plain
# `mojo_test`. Like the lit sources, they are excluded from the `mojo_test`
# glob in this file.
_FILECHECK_TESTS = [
    "test_cluster_dims_attribute.mojo",
    "test_cluster_sync.mojo",
    "test_debug_assert_gpu_error.mojo",
    "test_grid_dependence_control.mojo",
    "test_kernel_with_list.mojo",
    "test_print.mojo",
    "test_printf.mojo",
    "test_producer_consumer.mojo",
    "test_sync.mojo",
    "test_tbc_launch_config.mojo",
    "test_tensor_core_fp8_mma_sync.mojo",
]

# Per-source extra `copts`, keyed by source file name. The test comprehensions
# in this file look a source up with `.get(src, [])`, so sources that are not
# listed here get no additional options.
_EXTRA_COPTS = {
    "test_grid_dependence_control.mojo": [
        # NOTE(review): presumably a compile-time define of PDL_LEVEL=1;
        # confirm against the Mojo compiler's `-D` handling.
        "-D",
        "PDL_LEVEL=1",
    ],
}

# Constraint value shared by every test that must be skipped when the
# `//:apple_gpu` condition matches: it resolves to the always-unsatisfiable
# `@platforms//:incompatible` constraint there and to nothing elsewhere.
_INCOMPATIBLE_WITH_APPLE_GPU = select({
    "//:apple_gpu": ["@platforms//:incompatible"],
    "//conditions:default": [],
})

# Per-source additions to `target_compatible_with`, keyed by source file name.
# The test comprehensions in this file append the looked-up value to
# ["//:has_gpu"]; sources that are not listed here get no extra constraint.
_EXTRA_CONSTRAINTS = {
    # Tests tied to a specific GPU vendor or architecture.
    # NOTE(review): the original trailing note on the next entry read
    # 'KERN-1377:["//:amd_gpu"], disabled' and looked garbled; confirm intent.
    "test_amd_format.mojo": ["//:amd_gpu"],  # FIXME: KERN-1377
    "test_cluster_dims_attribute.mojo": ["//:h100_gpu"],
    "test_cluster_sync.mojo": ["//:h100_gpu"],
    "test_semaphore.mojo": ["//:nvidia_gpu"],
    "test_constant_memory.mojo": ["//:nvidia_gpu"],  # FIXME: KERN-1377
    "test_grid_dependence_control.mojo": ["//:h100_gpu"],
    "test_has_amd_gpu.mojo": ["//:amd_gpu"],
    "test_has_nvidia_gpu.mojo": ["//:nvidia_gpu"],
    "test_kernel_with_list.mojo": ["//:nvidia_gpu"],  # FIXME: KERN-1377 remove this
    "test_producer_consumer.mojo": ["//:h100_gpu"],
    "test_tbc_launch_config.mojo": ["//:h100_gpu"],
    "test_tensor_core_fp8_mma_sync.mojo": ["//:h100_gpu"],
    "test_wmma.mojo": ["//:nvidia_gpu"],
    "test_wmma_amd.mojo": ["//:amd_gpu"],

    # Tests marked incompatible when `//:apple_gpu` matches; each entry keeps
    # the ticket that tracks it.
    "test_convert.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: PRDT-506
    "test_cast_roundtrip.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: PRDT-506
    "test_debug_assert_gpu_error.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2405
    "test_init_vector_gpu.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2415
    "test_print_elementwise.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2366
    "test_prefix_sum.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2415
    "test_print.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2405
    "test_capture_trait_type.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2405
    "test_shuffle.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: PRDT-506
    "test_printf.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2405
    "test_fast_div.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2366
    "test_sum.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2397
    "test_random.mojo": _INCOMPATIBLE_WITH_APPLE_GPU,  # FIXME: MOCO-2397
}

# Mojo dependencies shared by every test target declared in this file.
# Declared once so the three target groups below cannot drift apart.
_GPU_TEST_DEPS = [
    "//max:internal_utils",
    "//max:kv_cache",
    "//max:linalg",
    "//max:nn",
    "//max:quantization",
    "@mojo//:stdlib",
]

# One `mojo_test` named "<source>.test" per .mojo file matched by the
# recursive glob, skipping the sources owned by the lit and FileCheck groups.
# Per-source compiler options and platform constraints come from
# _EXTRA_COPTS and _EXTRA_CONSTRAINTS.
[
    mojo_test(
        name = src + ".test",
        srcs = [src],
        copts = _EXTRA_COPTS.get(src, []),
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(src, []),
        deps = _GPU_TEST_DEPS,
    )
    for src in glob(
        ["**/*.mojo"],
        exclude = _LIT_TESTS + _FILECHECK_TESTS,
    )
]

# One `mojo_filecheck_test` named "<source>.test" per FileCheck source.
[
    mojo_filecheck_test(
        name = src + ".test",
        srcs = [src],
        copts = _EXTRA_COPTS.get(src, []),
        # Only the debug-assert test sets `expect_crash`; every other source
        # passes False.
        expect_crash = src == "test_debug_assert_gpu_error.mojo",
        tags = ["gpu"],
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(src, []),
        deps = _GPU_TEST_DEPS,
    )
    for src in _FILECHECK_TESTS
]

# A single `lit_tests` target covering all lit sources.
lit_tests(
    name = "lit",
    size = "large",
    srcs = _LIT_TESTS,
    gpu_constraints = ["//:has_gpu"],
    mojo_deps = _GPU_TEST_DEPS,
    tags = ["gpu"],
)
